1   /*
2    * Copyright (C) 2011 The Guava Authors
3    *
4    * Licensed under the Apache License, Version 2.0 (the "License");
5    * you may not use this file except in compliance with the License.
6    * You may obtain a copy of the License at
7    *
8    * http://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS,
12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13   * See the License for the specific language governing permissions and
14   * limitations under the License.
15   */
16  
17  package com.google.common.math;
18  
19  import static com.google.common.base.Preconditions.checkArgument;
20  import static com.google.common.base.Preconditions.checkNotNull;
21  import static com.google.common.math.MathPreconditions.checkNonNegative;
22  import static com.google.common.math.MathPreconditions.checkPositive;
23  import static com.google.common.math.MathPreconditions.checkRoundingUnnecessary;
24  import static java.math.RoundingMode.CEILING;
25  import static java.math.RoundingMode.FLOOR;
26  import static java.math.RoundingMode.HALF_EVEN;
27  
28  import com.google.common.annotations.GwtCompatible;
29  import com.google.common.annotations.GwtIncompatible;
30  import com.google.common.annotations.VisibleForTesting;
31  
32  import java.math.BigDecimal;
33  import java.math.BigInteger;
34  import java.math.RoundingMode;
35  import java.util.ArrayList;
36  import java.util.List;
37  
38  /**
39   * A class for arithmetic on values of type {@code BigInteger}.
40   *
41   * <p>The implementations of many methods in this class are based on material from Henry S. Warren,
42   * Jr.'s <i>Hacker's Delight</i>, (Addison Wesley, 2002).
43   *
44   * <p>Similar functionality for {@code int} and for {@code long} can be found in
45   * {@link IntMath} and {@link LongMath} respectively.
46   *
47   * @author Louis Wasserman
48   * @since 11.0
49   */
50  @GwtCompatible(emulated = true)
51  public final class BigIntegerMath {
52    /**
53     * Returns {@code true} if {@code x} represents a power of two.
54     */
55    public static boolean isPowerOfTwo(BigInteger x) {
56      checkNotNull(x);
57      return x.signum() > 0 && x.getLowestSetBit() == x.bitLength() - 1;
58    }
59  
60    /**
61     * Returns the base-2 logarithm of {@code x}, rounded according to the specified rounding mode.
62     *
63     * @throws IllegalArgumentException if {@code x <= 0}
64     * @throws ArithmeticException if {@code mode} is {@link RoundingMode#UNNECESSARY} and {@code x}
65     *         is not a power of two
66     */
67    @SuppressWarnings("fallthrough")
68    // TODO(kevinb): remove after this warning is disabled globally
69    public static int log2(BigInteger x, RoundingMode mode) {
70      checkPositive("x", checkNotNull(x));
71      int logFloor = x.bitLength() - 1;
72      switch (mode) {
73        case UNNECESSARY:
74          checkRoundingUnnecessary(isPowerOfTwo(x)); // fall through
75        case DOWN:
76        case FLOOR:
77          return logFloor;
78  
79        case UP:
80        case CEILING:
81          return isPowerOfTwo(x) ? logFloor : logFloor + 1;
82  
83        case HALF_DOWN:
84        case HALF_UP:
85        case HALF_EVEN:
86          if (logFloor < SQRT2_PRECOMPUTE_THRESHOLD) {
87            BigInteger halfPower = SQRT2_PRECOMPUTED_BITS.shiftRight(
88                SQRT2_PRECOMPUTE_THRESHOLD - logFloor);
89            if (x.compareTo(halfPower) <= 0) {
90              return logFloor;
91            } else {
92              return logFloor + 1;
93            }
94          }
95          /*
96           * Since sqrt(2) is irrational, log2(x) - logFloor cannot be exactly 0.5
97           *
98           * To determine which side of logFloor.5 the logarithm is, we compare x^2 to 2^(2 *
99           * logFloor + 1).
100          */
101         BigInteger x2 = x.pow(2);
102         int logX2Floor = x2.bitLength() - 1;
103         return (logX2Floor < 2 * logFloor + 1) ? logFloor : logFloor + 1;
104 
105       default:
106         throw new AssertionError();
107     }
108   }
109 
110   /*
111    * The maximum number of bits in a square root for which we'll precompute an explicit half power
112    * of two. This can be any value, but higher values incur more class load time and linearly
113    * increasing memory consumption.
114    */
115   @VisibleForTesting static final int SQRT2_PRECOMPUTE_THRESHOLD = 256;
116 
117   @VisibleForTesting static final BigInteger SQRT2_PRECOMPUTED_BITS =
118       new BigInteger("16a09e667f3bcc908b2fb1366ea957d3e3adec17512775099da2f590b0667322a", 16);
119 
120   /**
121    * Returns the base-10 logarithm of {@code x}, rounded according to the specified rounding mode.
122    *
123    * @throws IllegalArgumentException if {@code x <= 0}
124    * @throws ArithmeticException if {@code mode} is {@link RoundingMode#UNNECESSARY} and {@code x}
125    *         is not a power of ten
126    */
127   @GwtIncompatible("TODO")
128   @SuppressWarnings("fallthrough")
129   public static int log10(BigInteger x, RoundingMode mode) {
130     checkPositive("x", x);
131     if (fitsInLong(x)) {
132       return LongMath.log10(x.longValue(), mode);
133     }
134 
135     int approxLog10 = (int) (log2(x, FLOOR) * LN_2 / LN_10);
136     BigInteger approxPow = BigInteger.TEN.pow(approxLog10);
137     int approxCmp = approxPow.compareTo(x);
138 
139     /*
140      * We adjust approxLog10 and approxPow until they're equal to floor(log10(x)) and
141      * 10^floor(log10(x)).
142      */
143 
144     if (approxCmp > 0) {
145       /*
146        * The code is written so that even completely incorrect approximations will still yield the
147        * correct answer eventually, but in practice this branch should almost never be entered,
148        * and even then the loop should not run more than once.
149        */
150       do {
151         approxLog10--;
152         approxPow = approxPow.divide(BigInteger.TEN);
153         approxCmp = approxPow.compareTo(x);
154       } while (approxCmp > 0);
155     } else {
156       BigInteger nextPow = BigInteger.TEN.multiply(approxPow);
157       int nextCmp = nextPow.compareTo(x);
158       while (nextCmp <= 0) {
159         approxLog10++;
160         approxPow = nextPow;
161         approxCmp = nextCmp;
162         nextPow = BigInteger.TEN.multiply(approxPow);
163         nextCmp = nextPow.compareTo(x);
164       }
165     }
166 
167     int floorLog = approxLog10;
168     BigInteger floorPow = approxPow;
169     int floorCmp = approxCmp;
170 
171     switch (mode) {
172       case UNNECESSARY:
173         checkRoundingUnnecessary(floorCmp == 0);
174         // fall through
175       case FLOOR:
176       case DOWN:
177         return floorLog;
178 
179       case CEILING:
180       case UP:
181         return floorPow.equals(x) ? floorLog : floorLog + 1;
182 
183       case HALF_DOWN:
184       case HALF_UP:
185       case HALF_EVEN:
186         // Since sqrt(10) is irrational, log10(x) - floorLog can never be exactly 0.5
187         BigInteger x2 = x.pow(2);
188         BigInteger halfPowerSquared = floorPow.pow(2).multiply(BigInteger.TEN);
189         return (x2.compareTo(halfPowerSquared) <= 0) ? floorLog : floorLog + 1;
190       default:
191         throw new AssertionError();
192     }
193   }
194 
195   private static final double LN_10 = Math.log(10);
196   private static final double LN_2 = Math.log(2);
197 
198   /**
199    * Returns the square root of {@code x}, rounded with the specified rounding mode.
200    *
201    * @throws IllegalArgumentException if {@code x < 0}
202    * @throws ArithmeticException if {@code mode} is {@link RoundingMode#UNNECESSARY} and
203    *         {@code sqrt(x)} is not an integer
204    */
205   @GwtIncompatible("TODO")
206   @SuppressWarnings("fallthrough")
207   public static BigInteger sqrt(BigInteger x, RoundingMode mode) {
208     checkNonNegative("x", x);
209     if (fitsInLong(x)) {
210       return BigInteger.valueOf(LongMath.sqrt(x.longValue(), mode));
211     }
212     BigInteger sqrtFloor = sqrtFloor(x);
213     switch (mode) {
214       case UNNECESSARY:
215         checkRoundingUnnecessary(sqrtFloor.pow(2).equals(x)); // fall through
216       case FLOOR:
217       case DOWN:
218         return sqrtFloor;
219       case CEILING:
220       case UP:
221         int sqrtFloorInt = sqrtFloor.intValue();
222         boolean sqrtFloorIsExact =
223             (sqrtFloorInt * sqrtFloorInt == x.intValue()) // fast check mod 2^32
224             && sqrtFloor.pow(2).equals(x); // slow exact check
225         return sqrtFloorIsExact ? sqrtFloor : sqrtFloor.add(BigInteger.ONE);
226       case HALF_DOWN:
227       case HALF_UP:
228       case HALF_EVEN:
229         BigInteger halfSquare = sqrtFloor.pow(2).add(sqrtFloor);
230         /*
231          * We wish to test whether or not x <= (sqrtFloor + 0.5)^2 = halfSquare + 0.25. Since both
232          * x and halfSquare are integers, this is equivalent to testing whether or not x <=
233          * halfSquare.
234          */
235         return (halfSquare.compareTo(x) >= 0) ? sqrtFloor : sqrtFloor.add(BigInteger.ONE);
236       default:
237         throw new AssertionError();
238     }
239   }
240 
241   @GwtIncompatible("TODO")
242   private static BigInteger sqrtFloor(BigInteger x) {
243     /*
244      * Adapted from Hacker's Delight, Figure 11-1.
245      *
246      * Using DoubleUtils.bigToDouble, getting a double approximation of x is extremely fast, and
247      * then we can get a double approximation of the square root. Then, we iteratively improve this
248      * guess with an application of Newton's method, which sets guess := (guess + (x / guess)) / 2.
249      * This iteration has the following two properties:
250      *
251      * a) every iteration (except potentially the first) has guess >= floor(sqrt(x)). This is
252      * because guess' is the arithmetic mean of guess and x / guess, sqrt(x) is the geometric mean,
253      * and the arithmetic mean is always higher than the geometric mean.
254      *
255      * b) this iteration converges to floor(sqrt(x)). In fact, the number of correct digits doubles
256      * with each iteration, so this algorithm takes O(log(digits)) iterations.
257      *
258      * We start out with a double-precision approximation, which may be higher or lower than the
259      * true value. Therefore, we perform at least one Newton iteration to get a guess that's
260      * definitely >= floor(sqrt(x)), and then continue the iteration until we reach a fixed point.
261      */
262     BigInteger sqrt0;
263     int log2 = log2(x, FLOOR);
264     if (log2 < Double.MAX_EXPONENT) {
265       sqrt0 = sqrtApproxWithDoubles(x);
266     } else {
267       int shift = (log2 - DoubleUtils.SIGNIFICAND_BITS) & ~1; // even!
268       /*
269        * We have that x / 2^shift < 2^54. Our initial approximation to sqrtFloor(x) will be
270        * 2^(shift/2) * sqrtApproxWithDoubles(x / 2^shift).
271        */
272       sqrt0 = sqrtApproxWithDoubles(x.shiftRight(shift)).shiftLeft(shift >> 1);
273     }
274     BigInteger sqrt1 = sqrt0.add(x.divide(sqrt0)).shiftRight(1);
275     if (sqrt0.equals(sqrt1)) {
276       return sqrt0;
277     }
278     do {
279       sqrt0 = sqrt1;
280       sqrt1 = sqrt0.add(x.divide(sqrt0)).shiftRight(1);
281     } while (sqrt1.compareTo(sqrt0) < 0);
282     return sqrt0;
283   }
284 
285   @GwtIncompatible("TODO")
286   private static BigInteger sqrtApproxWithDoubles(BigInteger x) {
287     return DoubleMath.roundToBigInteger(Math.sqrt(DoubleUtils.bigToDouble(x)), HALF_EVEN);
288   }
289 
290   /**
291    * Returns the result of dividing {@code p} by {@code q}, rounding using the specified
292    * {@code RoundingMode}.
293    *
294    * @throws ArithmeticException if {@code q == 0}, or if {@code mode == UNNECESSARY} and {@code a}
295    *         is not an integer multiple of {@code b}
296    */
297   @GwtIncompatible("TODO")
298   public static BigInteger divide(BigInteger p, BigInteger q, RoundingMode mode) {
299     BigDecimal pDec = new BigDecimal(p);
300     BigDecimal qDec = new BigDecimal(q);
301     return pDec.divide(qDec, 0, mode).toBigIntegerExact();
302   }
303 
304   /**
305    * Returns {@code n!}, that is, the product of the first {@code n} positive
306    * integers, or {@code 1} if {@code n == 0}.
307    *
308    * <p><b>Warning:</b> the result takes <i>O(n log n)</i> space, so use cautiously.
309    *
310    * <p>This uses an efficient binary recursive algorithm to compute the factorial
311    * with balanced multiplies.  It also removes all the 2s from the intermediate
312    * products (shifting them back in at the end).
313    *
314    * @throws IllegalArgumentException if {@code n < 0}
315    */
316   public static BigInteger factorial(int n) {
317     checkNonNegative("n", n);
318 
319     // If the factorial is small enough, just use LongMath to do it.
320     if (n < LongMath.factorials.length) {
321       return BigInteger.valueOf(LongMath.factorials[n]);
322     }
323 
324     // Pre-allocate space for our list of intermediate BigIntegers.
325     int approxSize = IntMath.divide(n * IntMath.log2(n, CEILING), Long.SIZE, CEILING);
326     ArrayList<BigInteger> bignums = new ArrayList<BigInteger>(approxSize);
327 
328     // Start from the pre-computed maximum long factorial.
329     int startingNumber = LongMath.factorials.length;
330     long product = LongMath.factorials[startingNumber - 1];
331     // Strip off 2s from this value.
332     int shift = Long.numberOfTrailingZeros(product);
333     product >>= shift;
334 
335     // Use floor(log2(num)) + 1 to prevent overflow of multiplication.
336     int productBits = LongMath.log2(product, FLOOR) + 1;
337     int bits = LongMath.log2(startingNumber, FLOOR) + 1;
338     // Check for the next power of two boundary, to save us a CLZ operation.
339     int nextPowerOfTwo = 1 << (bits - 1);
340 
341     // Iteratively multiply the longs as big as they can go.
342     for (long num = startingNumber; num <= n; num++) {
343       // Check to see if the floor(log2(num)) + 1 has changed.
344       if ((num & nextPowerOfTwo) != 0) {
345         nextPowerOfTwo <<= 1;
346         bits++;
347       }
348       // Get rid of the 2s in num.
349       int tz = Long.numberOfTrailingZeros(num);
350       long normalizedNum = num >> tz;
351       shift += tz;
352       // Adjust floor(log2(num)) + 1.
353       int normalizedBits = bits - tz;
354       // If it won't fit in a long, then we store off the intermediate product.
355       if (normalizedBits + productBits >= Long.SIZE) {
356         bignums.add(BigInteger.valueOf(product));
357         product = 1;
358         productBits = 0;
359       }
360       product *= normalizedNum;
361       productBits = LongMath.log2(product, FLOOR) + 1;
362     }
363     // Check for leftovers.
364     if (product > 1) {
365       bignums.add(BigInteger.valueOf(product));
366     }
367     // Efficiently multiply all the intermediate products together.
368     return listProduct(bignums).shiftLeft(shift);
369   }
370 
371   static BigInteger listProduct(List<BigInteger> nums) {
372     return listProduct(nums, 0, nums.size());
373   }
374 
375   static BigInteger listProduct(List<BigInteger> nums, int start, int end) {
376     switch (end - start) {
377       case 0:
378         return BigInteger.ONE;
379       case 1:
380         return nums.get(start);
381       case 2:
382         return nums.get(start).multiply(nums.get(start + 1));
383       case 3:
384         return nums.get(start).multiply(nums.get(start + 1)).multiply(nums.get(start + 2));
385       default:
386         // Otherwise, split the list in half and recursively do this.
387         int m = (end + start) >>> 1;
388         return listProduct(nums, start, m).multiply(listProduct(nums, m, end));
389     }
390   }
391 
392  /**
393    * Returns {@code n} choose {@code k}, also known as the binomial coefficient of {@code n} and
394    * {@code k}, that is, {@code n! / (k! (n - k)!)}.
395    *
396    * <p><b>Warning:</b> the result can take as much as <i>O(k log n)</i> space.
397    *
398    * @throws IllegalArgumentException if {@code n < 0}, {@code k < 0}, or {@code k > n}
399    */
400   public static BigInteger binomial(int n, int k) {
401     checkNonNegative("n", n);
402     checkNonNegative("k", k);
403     checkArgument(k <= n, "k (%s) > n (%s)", k, n);
404     if (k > (n >> 1)) {
405       k = n - k;
406     }
407     if (k < LongMath.biggestBinomials.length && n <= LongMath.biggestBinomials[k]) {
408       return BigInteger.valueOf(LongMath.binomial(n, k));
409     }
410 
411     BigInteger accum = BigInteger.ONE;
412 
413     long numeratorAccum = n;
414     long denominatorAccum = 1;
415 
416     int bits = LongMath.log2(n, RoundingMode.CEILING);
417 
418     int numeratorBits = bits;
419 
420     for (int i = 1; i < k; i++) {
421       int p = n - i;
422       int q = i + 1;
423 
424       // log2(p) >= bits - 1, because p >= n/2
425 
426       if (numeratorBits + bits >= Long.SIZE - 1) {
427         // The numerator is as big as it can get without risking overflow.
428         // Multiply numeratorAccum / denominatorAccum into accum.
429         accum = accum
430             .multiply(BigInteger.valueOf(numeratorAccum))
431             .divide(BigInteger.valueOf(denominatorAccum));
432         numeratorAccum = p;
433         denominatorAccum = q;
434         numeratorBits = bits;
435       } else {
436         // We can definitely multiply into the long accumulators without overflowing them.
437         numeratorAccum *= p;
438         denominatorAccum *= q;
439         numeratorBits += bits;
440       }
441     }
442     return accum
443         .multiply(BigInteger.valueOf(numeratorAccum))
444         .divide(BigInteger.valueOf(denominatorAccum));
445   }
446 
447   // Returns true if BigInteger.valueOf(x.longValue()).equals(x).
448   @GwtIncompatible("TODO")
449   static boolean fitsInLong(BigInteger x) {
450     return x.bitLength() <= Long.SIZE - 1;
451   }
452 
453   private BigIntegerMath() {}
454 }